package com.skyline.energy.jdbc.datasource;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

import javax.sql.DataSource;

import com.skyline.common.utils.Assert;
import com.skyline.energy.exception.DataAccessException;
import com.skyline.energy.jdbc.impl.SqlThreadLocal;

/**
 * Master/slave data source.
 * <p>
 * Statements flagged as writes by the current {@link SqlThreadLocal} context are routed to a
 * randomly chosen master; all other statements are routed to a randomly chosen slave. Both pools
 * start empty, so at least one master (for writes) and one slave (for reads) must be configured
 * through the setters before the corresponding kind of statement can be executed.
 *
 * @author wuqh
 */
public class MasterSlaverDataSource extends AbstractDistributeDataSource implements DataSource {
	/** Pool used for write statements; the setters guarantee it never contains {@code null}. */
	private List<DataSource> masters = Collections.emptyList();

	/** Pool used for non-write statements; the setters guarantee it never contains {@code null}. */
	private List<DataSource> slavers = Collections.emptyList();

	/** Randomness for pool selection; {@link Random} instances are safe to share across threads. */
	private final Random random = new Random();

	/**
	 * Configures a single master, replacing any previously configured masters.
	 *
	 * @param master the master data source; must not be {@code null}
	 */
	public void setMaster(DataSource master) {
		// Fail at configuration time instead of on the first write statement.
		Assert.notNull(master, "主数据源不能为null");
		// Build the list fully before publishing it, so readers never observe an empty pool.
		List<DataSource> single = new ArrayList<DataSource>(1);
		single.add(master);
		this.masters = single;
	}

	/**
	 * Configures the master pool, replacing any previously configured masters. The list is copied.
	 *
	 * @param masters the master data sources; must not be {@code null} nor contain {@code null}
	 */
	public void setMasters(List<DataSource> masters) {
		this.masters = copyWithoutNulls(masters, "主数据源列表中不能包含null");
	}

	/**
	 * Configures the slave pool, replacing any previously configured slaves. The list is copied.
	 *
	 * @param slavers the slave data sources; must not be {@code null} nor contain {@code null}
	 */
	public void setSlavers(List<DataSource> slavers) {
		this.slavers = copyWithoutNulls(slavers, "从数据源列表中不能包含null");
	}

	/**
	 * Resolves the data source for the statement described by the current {@link SqlThreadLocal}
	 * context: a random master for write statements, a random slave otherwise.
	 *
	 * @return the selected data source, never {@code null}
	 * @throws DataAccessException if the pool required by the statement is empty
	 */
	protected DataSource getDataSource() {
		SqlThreadLocal local = SqlThreadLocal.get();
		Assert.notNull(local, "系统BUG，请联系作者");

		boolean write = local.isWriteType();
		DataSource dataSource = randomGet(write ? masters : slavers);

		if (dataSource == null) {
			throw new DataAccessException("无法获取" + (write ? "主" : "从") + "数据源，执行SQL[" + local.getSql() + "]");
		}

		return dataSource;
	}

	/**
	 * Picks one element of {@code dataSources} uniformly at random.
	 *
	 * @param dataSources the candidate pool
	 * @return a randomly chosen element, or {@code null} if the pool is empty
	 */
	protected DataSource randomGet(List<DataSource> dataSources) {
		if (dataSources.isEmpty()) {
			return null;
		}
		if (dataSources.size() == 1) {
			// Nothing to choose between; skip the random draw.
			return dataSources.get(0);
		}
		return dataSources.get(random.nextInt(dataSources.size()));
	}

	/**
	 * Returns a defensive copy of {@code source}, rejecting {@code null} elements so that a
	 * misconfigured pool fails immediately rather than intermittently at query time.
	 */
	private static List<DataSource> copyWithoutNulls(List<DataSource> source, String message) {
		// A null source still throws NullPointerException here, exactly as the plain copy did.
		List<DataSource> copy = new ArrayList<DataSource>(source);
		for (DataSource dataSource : copy) {
			Assert.notNull(dataSource, message);
		}
		return copy;
	}
}
